
Fuse gate/up expert projections in SwitchGLU #1032

Open
Thump604 wants to merge 3 commits into ml-explore:main from Thump604:feat/fused-gate-up-projection

Conversation


@Thump604 Thump604 commented Mar 21, 2026

Summary

Add a fuse_gate_up option to SwitchGLU that performs a single gather_qmm call with concatenated gate+up weights instead of two separate calls. This eliminates one kernel dispatch per MoE layer per token.

Follows the approach proposed in #956 by @BurntToastGPT: handle fusion at the model layer via sanitize.

Measured Results (from #956)

| Model | Improvement |
| --- | --- |
| Qwen3-30B-A3B | +8.6% |
| Qwen3.5-122B-A10B | +5.0% |
| MiniMax M2.5 (456B) | +5.1% |
| GPT-OSS-120B | +5.0% |
| OLMoE-1B-7B | +3.8% |

Changes (10 files, 6 model families)

Core:

  • switch_layers.py: SwitchGLU gets a fuse_gate_up=False parameter. When True, it creates a single gate_up_proj SwitchLinear with 2x hidden_dims. The forward pass auto-detects the fused layout via `"gate_up_proj" in self`.
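A minimal sketch of why the fusion is output-equivalent, using plain-Python matmuls as stand-ins for the two gather_qmm dispatches (the weight names and shapes here are illustrative, not the actual SwitchLinear code):

```python
def matmul(x, w):
    # x: (n,), w: (n, m) -> (m,)
    return [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(len(w[0]))]

x = [1.0, 2.0]
W_gate = [[0.5, -1.0], [2.0, 0.0]]   # (hidden=2, ffn=2)
W_up   = [[1.0, 1.0], [-0.5, 3.0]]   # (hidden=2, ffn=2)

# Unfused: two dispatches
gate = matmul(x, W_gate)
up = matmul(x, W_up)

# Fused: concatenate output columns -> one dispatch with 2x output dim, then split
W_fused = [g + u for g, u in zip(W_gate, W_up)]  # (hidden=2, 2*ffn=4)
fused = matmul(x, W_fused)
gate_f, up_f = fused[:2], fused[2:]

assert gate_f == gate and up_f == up
```

The fused call produces exactly the same gate and up activations; only the number of kernel launches changes.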

Models with already-fused checkpoint weights (stop splitting):

  • qwen3_5_moe.py, qwen3_vl_moe.py: Sanitize keeps fused weights as gate_up_proj.weight
  • llama4.py: Same pattern (contiguous split + swapaxes)
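For checkpoints that already ship fused weights, the sanitize change can be sketched as below. The key names, the contiguous half-split, and the function shape are all illustrative stand-ins, not the PR's actual code:

```python
def sanitize(key, w, fuse_gate_up):
    # Hypothetical sketch: `w` stands in for a fused (2*ffn_dim, hidden)
    # tensor stored in the checkpoint under a gate_up_proj key.
    if not fuse_gate_up:
        # old behavior: split the fused tensor into two halves
        half = len(w) // 2
        return {
            key.replace("gate_up_proj", "gate_proj"): w[:half],
            key.replace("gate_up_proj", "up_proj"): w[half:],
        }
    # new behavior: keep the fused tensor whole as gate_up_proj.weight
    return {key: w}

w = [1, 2, 3, 4]  # stand-in for a (2*ffn_dim, hidden) matrix
assert sanitize("experts.gate_up_proj.weight", w, False) == {
    "experts.gate_proj.weight": [1, 2],
    "experts.up_proj.weight": [3, 4],
}
assert sanitize("experts.gate_up_proj.weight", w, True) == {
    "experts.gate_up_proj.weight": [1, 2, 3, 4],
}
```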

Models with per-expert weights (stack + concatenate):

  • olmoe.py: Stack per-expert gate/up, concatenate into gate_up_proj (handles quantized weights/scales/biases)
  • mixtral.py: Same pattern with w1/w2/w3 naming
  • minimax.py: Same pattern (FP8 dequant then fuse)
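The stack-and-concatenate pattern for per-expert checkpoints can be sketched as follows, with per-expert matrices represented as nested lists (function name and shapes are illustrative):

```python
def fuse_experts(gates, ups):
    # gates, ups: lists of per-expert matrices, each (ffn_dim, hidden_dim).
    # Stacks across experts and concatenates gate and up along the output
    # dimension, yielding a (num_experts, 2*ffn_dim, hidden_dim) tensor.
    return [g + u for g, u in zip(gates, ups)]

gates = [[[1, 2]], [[3, 4]]]   # 2 experts, ffn_dim=1, hidden_dim=2
ups   = [[[5, 6]], [[7, 8]]]
fused = fuse_experts(gates, ups)
assert fused == [[[1, 2], [5, 6]], [[3, 4], [7, 8]]]
```

For quantized checkpoints the same concatenation would have to be applied to the weights, scales, and biases consistently, as the OLMoE change above does.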

Constructor updates:

  • qwen3_next.py, qwen3_moe.py: Pass fuse_gate_up=True

Sharding:

  • qwen3_5.py: Handles both fused and unfused paths

Backward compatible — fuse_gate_up=False (default) preserves existing behavior for all other models using SwitchGLU. GPT-OSS excluded (interleaved weights need separate handling).

Test plan

  • Verify token-exact output matches unfused path on Qwen3.5-122B
  • Benchmark tok/s before/after on M2 Ultra
  • Verify quantized models load correctly with fused weights
  • Verify Mixtral, OLMoE, MiniMax load and generate correctly

Add fuse_gate_up option to SwitchGLU that uses a single gather_qmm
call with 2x hidden_dims instead of two separate calls for gate_proj
and up_proj. Eliminates one kernel dispatch per MoE layer per token.

Measured +5% tok/s on Qwen3.5-122B, +8.6% on Qwen3-30B,
+5.1% on MiniMax M2.5, +3.8% on OLMoE (ref: ml-explore#956).

Models updated: Qwen3, Qwen3.5 (all variants), Llama 4,
Mixtral, MiniMax, OLMoE.
@Thump604 force-pushed the feat/fused-gate-up-projection branch from 4ccf76f to 81270f9 on March 21, 2026 03:21
@Thump604
Author

Update: Metal OOM on memory-constrained setup

Testing on M2 Ultra 128GB with Qwen3.5-122B-A10B (5-bit, ~82GB weights), the fused gather_qmm with 2x output dimension causes kIOGPUCommandBufferCallbackErrorOutOfMemory on the first inference request. Reverting to unfused (two separate gather_qmm calls) works fine.

The 122B model leaves only ~46GB headroom for KV cache + Metal scratch + OS. The single larger gather_qmm likely requires more intermediate buffer space than two smaller calls combined.

The original benchmarks in #956 were on M3 Ultra 512GB where scratch memory isn't a constraint. This suggests the optimization needs a memory-aware fallback — perhaps fuse only when headroom is sufficient, or cap the fused dimension.
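A hypothetical headroom check for such a fallback, not part of this PR, might look like the following. The function name, threshold, and memory figures are illustrative only:

```python
def should_fuse(total_mem_gb, weights_gb, min_headroom_gb=64):
    # Speculative sketch: enable the fused gather_qmm only when the gap
    # between device memory and resident weights leaves enough room for
    # KV cache plus Metal scratch buffers. Threshold is a placeholder.
    return (total_mem_gb - weights_gb) >= min_headroom_gb

assert should_fuse(512, 82) is True    # M3 Ultra 512GB: fuse
assert should_fuse(128, 82) is False   # M2 Ultra 128GB: fall back to unfused
```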

Will investigate whether MLX_METAL_CACHE_LIMIT or intermediate buffer sizing is the bottleneck.

Read fuse_gate_up from model config (default True) instead of
hardcoding. On memory-constrained systems (e.g. 122B on 128GB),
set "fuse_gate_up": false in config.json to use separate gate/up
projections that allow Metal to deallocate scratch between kernels.

Also fixes sanitize() in qwen3_next and qwen3_moe which produced
unfused weight names while the model expected fused — a mismatch
that would break loading from HuggingFace format.
@Thump604
Author

Update: OOM fix implemented

Added config-driven fuse_gate_up (commit 801ba8c):

  • fuse_gate_up is now read from model config via getattr(args, "fuse_gate_up", True) — defaults to True for the +5% speedup
  • Set "fuse_gate_up": false in config.json for memory-constrained setups (e.g. 122B on M2 Ultra 128GB) to use separate gate/up projections
  • The unfused path allows Metal to deallocate scratch memory between the two gather_qmm kernels
  • Also fixed a bug: qwen3_next.py and qwen3_moe.py sanitize produced unfused weight names while the model init expected fused — this would break loading from HuggingFace format

All 10 model files now support both paths, controlled by a single config field.
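The config read described in this comment can be sketched as below (`Args` is an illustrative stand-in for the parsed model config; the default shown here is True as in this commit, though a later commit in this thread flips it to False):

```python
class Args:
    pass  # stand-in for the parsed model config; no fuse_gate_up field set

args = Args()
# missing field falls back to the default
assert getattr(args, "fuse_gate_up", True) is True

# "fuse_gate_up": false in config.json for memory-constrained setups
args.fuse_gate_up = False
assert getattr(args, "fuse_gate_up", True) is False
```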

@kernelpool
Contributor

I tried this on some of my existing MiniMax DWQ quants (e.g. catalystsec/MiniMax-M2.5-4bit-DWQ), but unfortunately they no longer work with these changes, so you may want to look at the impact on existing (quantized) models.

Existing quantized models (e.g. MiniMax DWQ quants) have separate
gate_proj/up_proj weights. Defaulting to True broke them because
the model init creates gate_up_proj while weights have unfused names,
and quantized weights cannot be concatenated (scales/biases mismatch).

Default to False so existing models work unchanged. New conversions
can opt in by setting "fuse_gate_up": true in config.json.
@Thump604
Author

Thanks @kernelpool for the report! Fixed in e7de34b.

The issue was that fuse_gate_up defaulted to True, which changed the expected weight names from gate_proj/up_proj to gate_up_proj. Existing quantized models have the unfused format, and quantized weights can't be concatenated (scales/biases are per-group).

Fix: Default is now False — existing models work unchanged. To opt in to the +5% speedup, set "fuse_gate_up": true in your model's config.json (should be done at quantize/convert time for new models).

@Thump604
Author

@angeloskath @awni — this PR has been open since March 20 with no maintainer review. A community member (kernelpool) found a breaking-change issue with existing quantized models, which we fixed the same day — default is now False (opt-in only, zero impact on existing models).

Is there a concern with the approach? Happy to rework if needed. The fused path gives ~5% speedup on SwitchGLU MoE models for users who opt in via config.

